import time
import requests
from quant.utils import Saving, Timer2, ValueUpdater, abstract_method, logging, catch_exception
from quant.const import Exchange, Event



























class _PrecisionSummary:
    """Per-exchange precision table cached in a ``Saving`` store.

    Lookups are served from ``self.saving``; on a miss the whole table is
    re-fetched through ``get_fn``, but at most once every 30 seconds.

    TODO (translated from the original note): the rate limit could be changed
    from per-symbol to global. NOTE(review): ``last_time`` is already a single
    timestamp shared by every symbol, so the limit is effectively global today.
    """

    def __init__(self, name, get_fn):
        self.get_fn = get_fn
        self.saving = Saving(name)
        # Backdate the last refresh by the full 30 s cooldown so that the very
        # first cache miss is allowed to trigger a refresh.
        self.last_time = time.time() - 30
        if _quick_mode:
            # Quick mode: no eager fetch; the table is filled on the first miss.
            return
        self._update()

    def get_precision(self, symbol):
        """Return the stored precision entry for *symbol*.

        On a cache miss the table is refreshed (unless the previous refresh
        was less than 30 s ago) and the lookup is retried once.

        :return: the value stored for *symbol*, or ``(None, None)`` when the
            symbol is still unknown or the refresh was refused by the cooldown.
        """
        try:
            return self.saving.get(symbol)
        except KeyError:
            pass

        cooling_down = time.time() - self.last_time < 30
        if cooling_down:
            logging.warn('get_precision({}), too frequently, refused'.format(symbol))
            return None, None

        logging.warn('get precision key error: {}, update()'.format(symbol))
        self._update()

        try:
            found = self.saving.get(symbol)
        except KeyError:
            logging.error('Cannot get precision of {}'.format(symbol))
            found = (None, None)
        return found

    def _update(self):
        """Re-fetch the full precision table via ``get_fn`` (up to 3 attempts)."""
        # Stamp before fetching so that failed fetches also count toward the
        # 30 s cooldown checked in get_precision.
        self.last_time = time.time()
        fetched = _try_get(self.get_fn, 3)
        if not fetched:
            # Nothing usable came back; keep the previously stored table.
            return
        # Hand the new table to Saving.override (presumably replaces the
        # stored mapping wholesale — confirm against quant.utils.Saving).
        self.saving.override(fetched)


class _ValueSummary:
    """Per-asset value table averaged over several exchanges' tickers.

    Each positional ``func`` is a zero-argument callable returning a ticker
    dict ``{symbol: {'price': ...}}`` (``get_asset_value`` passes the
    exchanges' ``get_all_ticker`` methods). Only symbols quoted in
    ``usdt``/``busd`` (any ``.suffix`` such as ``.swap`` is ignored) are used.
    An asset's value is the mean of ``price * get_asset_value('usdt')`` over
    every matching ticker.

    The table is refreshed once at construction (unless quick mode is on),
    by ``Timer2`` (presumably every ``_value_update_interval`` seconds), and
    on demand when ``asset_value`` misses.
    """

    def __init__(self, *func):
        self.saving = Saving('value_summary')
        self._funcs = func
        # 0 means "never refreshed", so the first _update is never throttled.
        self.last_time = 0

        if not _quick_mode:
            self._update()
        Timer2(self._update, _value_update_interval).start()

    def asset_value(self, asset):
        """Return the cached value of *asset*, refreshing the table once on a miss.

        :return: the stored value, or None when the asset is still unknown
            after the refresh (including when the refresh was throttled).
        """
        try:
            return self.saving.get(asset)
        except KeyError:
            pass

        logging.warn('get asset value error: {}, update()'.format(asset))
        # Pass the asset so a throttled refresh logs which lookup was refused,
        # instead of the uninformative "get_asset_value(None)".
        self._update(asset)
        try:
            return self.saving.get(asset)
        except KeyError:
            return None

    def _update_old(self):
        """Legacy refresh, superseded by ``_update`` and kept for reference.

        Unlike ``_update`` it is not throttled. When an asset is quoted in
        several currencies (e.g. usdt and busd), whichever symbol is processed
        last wins instead of all prices being averaged together.
        """
        self.last_time = time.time()

        all_tickers = [t for t in (_try_get(fn, 3) for fn in self._funcs) if t]
        all_symbols = {symbol for tickers in all_tickers for symbol in tickers}

        _usdt = get_asset_value('usdt')
        result = {}
        for symbol in all_symbols:
            # partition (not split) so malformed symbols are skipped, not fatal.
            asset, _, money = symbol.partition('/')
            if money.split('.')[0] not in ('usdt', 'busd'):
                continue

            values = []
            for tickers in all_tickers:
                try:
                    value = tickers[symbol]['price'] * _usdt
                except KeyError:
                    continue
                if value != 0:
                    values.append(value)
            # Guard against ZeroDivisionError when every quote is missing or zero.
            if values:
                result[asset] = sum(values) / len(values)
        self.saving.update(result)

    def _update(self, for_symbol=None):
        """Rebuild the value table from every ticker source.

        :param for_symbol: only used in the throttle warning, to name the
            lookup that triggered a refused refresh.
        :return: the freshly computed ``{asset: value}`` dict, or None when
            called less than 3 s after the previous refresh.
        """
        now = time.time()
        if now - self.last_time < 3:
            logging.warn('get_asset_value({}) too frequently, refused'.format(for_symbol))
            return None
        self.last_time = now

        _usdt = get_asset_value('usdt')
        asset_to_list = {}
        for fn in self._funcs:
            tickers = _try_get(fn, 3)
            if not tickers:
                continue

            for sy, dic in tickers.items():
                # partition (not split): a malformed symbol such as 'btc' or
                # 'a/b/c' is skipped instead of raising ValueError and
                # aborting the whole refresh.
                asset, _, money = sy.partition('/')
                if money.split('.')[0] not in ('usdt', 'busd'):
                    continue

                price = dic.get('price')
                # Skip missing, None or zero prices: they would either crash
                # the multiplication or drag the average towards 0.
                if not price:
                    continue

                asset_to_list.setdefault(asset, []).append(price * _usdt)

        result = {asset: sum(li) / len(li) for asset, li in asset_to_list.items()}
        self.saving.update(result)
        return result


def _try_get(get_fn, count):
    for i in range(count):
        try:
            result = get_fn()
            return result
        except:
            catch_exception()


def _get_prec_instance(exchange):
    """Return the cached ``_PrecisionSummary`` for *exchange*, creating it on first use.

    The summary wraps ``get_func_instance(exchange).get_all_precision`` and is
    persisted under the name ``precision_<exchange>``.
    """
    if exchange not in _prec_instance_map:
        # Report creation through logging (as _get_size_instance does) instead
        # of a bare print to stdout.
        logging.info('create precision instance: {}'.format(exchange))
        query = get_func_instance(exchange)
        save_name = 'precision_{}'.format(exchange)
        _prec_instance_map[exchange] = _PrecisionSummary(save_name, query.get_all_precision)
    return _prec_instance_map[exchange]


def _get_size_instance(exchange):
    """Return the cached contract-size ``ValueUpdater`` for *exchange*, creating it lazily.

    The updater is named ``<exchange>_contract_size``. It wraps
    ``get_func_instance(exchange).get_all_contract_size`` and is given
    ``_value_update_interval`` (presumably its refresh period in seconds).
    """
    try:
        return _size_instance_map[exchange]
    except KeyError:
        pass

    logging.info('create size instance: {}'.format(exchange))
    functions = get_func_instance(exchange)
    updater = ValueUpdater(
        '{}_contract_size'.format(exchange),
        functions.get_all_contract_size,
        _value_update_interval,
    )
    _size_instance_map[exchange] = updater
    return updater


# Lazily populated caches, keyed by exchange.
_func_instance_map = {}  # exchange -> Functions instance (see get_func_instance)
_prec_instance_map = {}  # exchange -> _PrecisionSummary (see _get_prec_instance)
_size_instance_map = {}  # exchange -> contract-size ValueUpdater (see _get_size_instance)

# Fixed asset values, consulted before the ticker-based summary.
# NOTE(review): 6.5 for the USD-pegged assets looks like a hard-coded
# USD->CNY rate — confirm and keep it up to date.
_value_map = dict.fromkeys(('usdt', 'usd', 'busd'), 6.5)
_value_summary = None  # module-wide _ValueSummary, created on the first get_asset_value miss

# Exchanges whose tickers feed the _ValueSummary average.
_value_cal_exchanges = [
    Exchange.Binance,
    Exchange.Okex,
    Exchange.Huobi,
    Exchange.Gate,
    Exchange.Kucoin,
]
_value_update_interval = 30 * 60  # 1800; presumably seconds between periodic refreshes
_quick_mode = False  # set_quick_mode() flips this to skip eager refreshes at construction


class Functions:
    """Base class for per-exchange market-data helpers.

    Exchange-specific subclasses override the methods below. The bodies here
    only illustrate the expected shape of each return value.
    """

    def __init__(self):
        # Shared requests session, presumably reused by subclasses for HTTP calls.
        self.session = requests.Session()

    @abstract_method
    def get_all_ticker(self):
        """Return ``{symbol: ticker}`` for every symbol on the exchange.

        Ticker fields:
          - ``price``: the quote price
          - ``vol``: 24h turnover, in CNY
          - ``bid`` / ``ask``: presumably best bid / best ask
        A former ``value`` field (CNY value of one unit) has been removed.
        """
        symbol = ...
        sample = dict(price=1, vol=2, bid=4, ask=5)
        return {symbol: sample}

    @abstract_method
    def get_all_precision(self):
        """Return ``{symbol: (precision, precision)}`` as decimal strings.

        NOTE(review): the tuple order is presumably (price, amount) — confirm.
        """
        symbol = ...
        return {symbol: ('0.01', '0.001')}

    @abstract_method
    def get_value_factor(self, symbol):
        """Return the value factor of *symbol*."""

    @abstract_method
    def get_all_contract_size(self):
        """Return contract sizes: how many units of the asset one unit of quantity represents."""

    @abstract_method
    def get_recent_trade(self, symbol, limit=1000):
        """Return up to *limit* recent trades of *symbol*.

        Shape: ``[['id', 'side', 'price', 'amount', 'ts'], ...]``.
        """
        trades = ...
        return trades

    @abstract_method
    def simulate_deal(self, order, deal_amount, mid_price, account):
        """Simulate filling *deal_amount* of *order* against *account*."""


def set_quick_mode():
    """Enable quick mode for this module and for ``ValueUpdater``.

    Summaries constructed afterwards skip their eager refresh in ``__init__``
    and are filled on demand instead.
    """
    global _quick_mode
    print('quick mode is set!')
    _quick_mode = True
    ValueUpdater.set_quick_mode()


def set_asset_value(asset, value):
    """Pin *asset* to a fixed *value*; ``get_asset_value`` then returns it without querying any exchange."""
    _value_map.update({asset: value})


def get_func_instance(exchange):
    """Return the shared helper instance for *exchange*, creating it on first use.

    The instance comes from ``quant.exchanges.get_factory(exchange).function_cls()``
    (presumably a ``Functions`` subclass). The import is deferred to call time,
    presumably to avoid an import cycle — verify.
    """
    if exchange in _func_instance_map:
        return _func_instance_map[exchange]

    from quant.exchanges import get_factory
    instance = get_factory(exchange).function_cls()
    _func_instance_map[exchange] = instance
    return instance


def get_asset_value(asset):
    """Return the value of one unit of *asset*.

    Fixed entries in ``_value_map`` (see ``set_asset_value``) take priority.
    Any other asset is looked up in a lazily created module-wide
    ``_ValueSummary``, which is fed by the tickers of every exchange in
    ``_value_cal_exchanges``. That lookup may return None when the asset
    cannot be found.
    """
    global _value_summary

    if asset in _value_map:
        return _value_map[asset]

    if _value_summary is None:
        ticker_fns = [get_func_instance(exchange).get_all_ticker for exchange in _value_cal_exchanges]
        _value_summary = _ValueSummary(*ticker_fns)

    return _value_summary.asset_value(asset)


def get_precision(exchange, symbol):
    """Return the precision entry of *symbol* on *exchange*, or ``(None, None)`` if unavailable."""
    return _get_prec_instance(exchange).get_precision(symbol)


def get_value_factor(exchange, symbol):
    """Delegate to the exchange helper's ``get_value_factor(symbol)``."""
    return get_func_instance(exchange).get_value_factor(symbol)


def get_contract_size(exchange, symbol):
    """Return the contract size of *symbol* from the exchange's cached contract-size updater."""
    return _get_size_instance(exchange).get_value(symbol)


def has_symbol(exchange, symbol):
    """Return True when *exchange* has a complete precision entry for *symbol*.

    An entry containing ``None`` in any position (including the
    ``(None, None)`` fallback for unknown symbols) counts as "not listed".
    """
    return None not in get_precision(exchange, symbol)


"""
尚未测试多个value_cal_exchanges, value_summary会不会取平均！
"""


if __name__ == '__main__':
    # Manual scratch harness: switch on test mode and quick mode, then call one
    # of the helpers below by hand (none of them runs automatically).
    from quant.utils import set_test_mode as _set_test_mode

    _set_test_mode()
    set_quick_mode()

    def try_get_asset_value():
        """Loop forever printing get_asset_value('xrp'), waiting for Enter between polls."""
        from quant.utils import set_test_mode
        set_test_mode()
        while True:
            print(get_asset_value('xrp'))
            input(':')

    def try_get_value_factor():
        """Print the same Binance value factor three times in a row."""
        for _ in range(3):
            print(get_value_factor('Binance', 'doge/usdt.swap'))

    def try_get():
        """Scratch pad: uncomment the lookups to try."""
        e = 'Binance'
        # print(get_precision(e, 'btc/usdt.swap'))
        # print(get_precision(e, 'btc/usdt'))
        # print(get_asset_value('btc'))
        # print(get_value_factor(e, 'btc/usdt'))

    def perf_get_value_with_size():
        """Print one Huobi value-factor lookup, then run perf_test on it 1000 times."""
        from quant.utils import perf_test

        def lookup():
            return get_value_factor('Huobi', 'matic/usdt.swap')

        print(lookup())
        for _ in range(1000):
            perf_test(lookup)

    # Manual _ValueSummary construction, kept for reference:
    # vs = _ValueSummary(*[get_func_instance(ex).get_all_ticker for ex in _value_cal_exchanges])
    # print(vs)














